import re

class Vulnerability:
    """Base class for a vulnerability report built from a list of pcs.

    Subclasses set ``self.name`` (used as the warning title) before the
    warnings are rendered.

    Attributes:
        source_map: object providing ``get_source_code``, ``get_buggy_line``,
            ``get_location``, ``get_filename``, ``instr_positions`` and
            ``root_path``; may be falsy when no source is available.
        pcs: the pcs that survived false-positive filtering.
        warnings: rendered warning strings (empty when there is no source map).
    """

    def __init__(self, source_map, pcs):
        self.source_map = source_map
        self.pcs = self._rm_general_false_positives(pcs)
        # Always define ``warnings`` so get_warnings()/__str__ keep working
        # when no source map is available.
        self.warnings = self._warnings() if source_map else []

    def is_vulnerable(self):
        """Return True if at least one pc remains after filtering."""
        return bool(self.pcs)

    def get_warnings(self):
        """Return the list of rendered warning strings."""
        return self.warnings

    def _rm_general_false_positives(self, pcs):
        """Drop pcs without source code and collapse pcs sharing a position.

        Without a source map, ``pcs`` is returned unchanged.
        """
        if not self.source_map:
            return pcs
        new_pcs = self._rm_pcs_having_no_source_code(pcs)
        return self._reduce_pcs_having_the_same_pos(new_pcs)

    def _rm_pcs_having_no_source_code(self, pcs):
        """Keep only pcs for which the source map returns non-empty source code."""
        return [pc for pc in pcs if self.source_map.get_source_code(pc)]

    def _reduce_pcs_having_the_same_pos(self, pcs):
        """Keep the first pc seen for each distinct ``instr_positions`` entry.

        Position entries look like ``{'begin': 417, 'end': 435, 'name': 'CALL'}``;
        they are compared through ``str()`` because dicts are not hashable.
        """
        first_pc_by_pos = {}
        for pc in pcs:
            pos = str(self.source_map.instr_positions[pc])
            if pos not in first_pc_by_pos:
                first_pc_by_pos[pos] = pc
        # Return a real list (not a dict view) so callers can index and reuse it.
        return list(first_pc_by_pos.values())

    def _warnings(self):
        """Render one warning string per pc that maps to source code."""
        warnings = []
        for pc in self.pcs:
            if not self.source_map.get_source_code(pc):
                continue
            buggy_line = self.source_map.get_buggy_line(pc)
            warning = self._warning_content(pc, buggy_line)
            if warning:
                warnings.append(warning)
        return warnings

    def _warning_content(self, pc, source_code):
        """Format a ``file:line:col: Warning: <name>.`` message for ``pc``.

        Only the first line of ``source_code`` is shown. When it spans several
        lines, a caret line and 'Spanning multiple lines.' are appended. The
        line/column from ``get_location`` are incremented by 1 (presumably the
        source map is 0-based -- confirm against SourceMap).
        """
        first_line, sep, _ = source_code.partition('\n')
        location = self.source_map.get_location(pc)

        # root_path is a literal path prefix, not a regular expression: escape
        # it so characters such as '.', '+' or '(' match themselves.
        source = re.sub(re.escape(self.source_map.root_path), '',
                        self.source_map.get_filename())
        line = location['begin']['line'] + 1
        column = location['begin']['column'] + 1
        s = '%s:%s:%s: Warning: %s.\n' % (source, line, column, self.name)
        s += first_line
        if sep:
            s += '\n' + self._leading_spaces(first_line) + '^\n'
            s += 'Spanning multiple lines.'
        return s

    def _leading_spaces(self, s):
        """Return the run of spaces and tabs that prefixes ``s``."""
        # str.lstrip takes a *set of characters*, not a regex; the former
        # '[ \t]' argument also stripped '[' and ']'.
        return s[:len(s) - len(s.lstrip(' \t'))]

    def __str__(self):
        """Return all warnings separated by newlines."""
        return '\n'.join(self.warnings).lstrip('\n')

class CallStack(Vulnerability):
    """Report titled 'Callstack Depth Attack Vulnerability'.

    In addition to the general false-positive filtering, pcs whose entry in
    ``calls_affect_state`` is falsy are dropped.
    """

    def __init__(self, source_map, pcs, calls_affect_state):
        self.source_map = source_map
        # Set the title unconditionally so the object is usable even
        # without a source map.
        self.name = 'Callstack Depth Attack Vulnerability'
        self.pcs = self._rm_false_positives(pcs, calls_affect_state)
        # Always define ``warnings`` so the inherited get_warnings()/__str__
        # do not raise AttributeError when there is no source map.
        self.warnings = Vulnerability._warnings(self) if source_map else []

    def _rm_false_positives(self, pcs, calls_affect_state):
        """Apply the general filters, then drop pcs known not to affect state."""
        new_pcs = Vulnerability._rm_general_false_positives(self, pcs)
        return self._rm_pcs_not_affect_state(new_pcs, calls_affect_state)

    def _rm_pcs_not_affect_state(self, pcs, calls_affect_state):
        """Keep pcs that are absent from ``calls_affect_state`` or map to a truthy value."""
        # Equivalent to the former
        # ``pc in d and d[pc] or pc not in d`` test.
        return [pc for pc in pcs if calls_affect_state.get(pc, True)]

class TimeDependency(Vulnerability):
    """Vulnerability report whose warnings are titled 'Timestamp Dependency'."""

    def __init__(self, source_map, pcs):
        # The title has to exist before the base constructor renders warnings.
        title = 'Timestamp Dependency'
        self.name = title
        Vulnerability.__init__(self, source_map=source_map, pcs=pcs)

class Reentrancy(Vulnerability):
    """Vulnerability report whose warnings are titled 'Re-Entrancy Vulnerability'."""

    def __init__(self, source_map, pcs):
        # The title has to exist before the base constructor renders warnings.
        title = 'Re-Entrancy Vulnerability'
        self.name = title
        Vulnerability.__init__(self, source_map=source_map, pcs=pcs)

class MoneyConcurrency(Vulnerability):
    """Report titled 'Transaction-Ordering Dependency', grouped by flow.

    ``flows`` is a sequence whose items are each an iterable of pcs; warnings
    are rendered separately for every flow.
    """

    def __init__(self, source_map, flows):
        self.name = 'Transaction-Ordering Dependency'
        self.source_map = source_map
        self.flows = flows
        # Always define the attribute so get_warnings_of_flows()/__str__
        # keep working without a source map.
        self.warnings_of_flows = self._warnings_of_flows() if source_map else []

    def is_vulnerable(self):
        """Return True if any flow was reported."""
        return bool(self.flows)

    def get_warnings_of_flows(self):
        """Return a list holding one list of warning strings per flow."""
        return self.warnings_of_flows

    def _warnings_of_flows(self):
        """Render the warnings of each flow after general false-positive filtering."""
        warnings_of_flows = []
        for flow in self.flows:
            pcs = Vulnerability._rm_general_false_positives(self, flow)
            warnings = []
            for pc in pcs:
                if not self.source_map.get_source_code(pc):
                    continue
                buggy_line = self.source_map.get_buggy_line(pc)
                warning = Vulnerability._warning_content(self, pc, buggy_line)
                if warning:
                    warnings.append(warning)
            warnings_of_flows.append(warnings)
        return warnings_of_flows

    def __str__(self):
        """Return 'Flow1', 'Flow2', ... each followed by that flow's warnings."""
        sections = []
        for i, warnings in enumerate(self.warnings_of_flows):
            sections.append('\n'.join(['Flow' + str(i + 1)] + list(warnings)))
        return '\n'.join(sections)

class AssertionFailure(Vulnerability):
    """Report built from assertion objects exposing ``.pc`` and ``.model``.

    The title is derived from the class name by splitting on capitalised
    words, e.g. ``AssertionFailure`` -> 'Assertion Failure'.

    Raises:
        Exception: when instantiated as ``AssertionFailure`` itself without a
            source map.
    """

    def __init__(self, source_map, assertions):
        self.source_map = source_map
        self.name = ' '.join(re.findall(r'[A-Z][a-z]+', self.__class__.__name__))
        if not source_map:
            # The former guard compared the *spaced* title with
            # 'AssertionFailure' (never equal) and ran only after the source
            # map had already been dereferenced, so it could never fire.
            if self.__class__ is AssertionFailure:
                raise Exception("source_map attribute can't be None")
            # Subclasses: nothing to de-duplicate or render against.
            self.assertions = list(assertions)
            self.warnings = []
            return
        self.assertions = self._reduce_pcs_having_the_same_pos(assertions)
        self.warnings = self._warnings()

    def is_vulnerable(self):
        """Return True if at least one assertion remains."""
        return bool(self.assertions)

    def _reduce_pcs_having_the_same_pos(self, assertions):
        """Keep the first assertion seen for each distinct ``instr_positions`` entry."""
        first_by_pos = {}
        for asrt in assertions:
            pos = str(self.source_map.instr_positions[asrt.pc])
            if pos not in first_by_pos:
                first_by_pos[pos] = asrt
        # A real list (not a dict view) so callers can index and reuse it.
        return list(first_by_pos.values())

    def _warnings(self):
        """Render a warning per assertion, plus the model values that triggered it.

        Only model declarations that the source map recognises through
        ``get_parameter_or_state_var`` are listed. ``asrt.model`` is presumably
        a z3 model (uses ``decls()`` and indexing) -- confirm at the caller.
        """
        warnings = []
        for asrt in self.assertions:
            source_code = self.source_map.get_buggy_line(asrt.pc)
            s = Vulnerability._warning_content(self, asrt.pc, source_code)

            model = ''
            for variable in asrt.model.decls():
                var_name = str(variable)
                # Names containing at least two '-' are reduced to their
                # third '-'-separated component.
                parts = var_name.split('-')
                if len(parts) > 2:
                    var_name = parts[2]
                if self.source_map.get_parameter_or_state_var(var_name):
                    model += '\n    ' + var_name + ' = ' + str(asrt.model[variable])
            if model:
                s += "\n%s occurs if:%s" % (self.name, model)
            if s:
                warnings.append(s)
        return warnings

class IntegerUnderflow(AssertionFailure):
    """AssertionFailure variant; its title 'Integer Underflow' comes from the class name."""

class IntegerOverflow(AssertionFailure):
    """AssertionFailure variant; its title 'Integer Overflow' comes from the class name."""

class ParityMultisigBug2(Vulnerability):
    """Report for callee contracts whose disassembly contains SELFDESTRUCT/SUICIDE.

    Each entry of ``source_map.callee_src_pairs`` is a pair: the first item
    is a path prefix for a ``<prefix>.evm.disasm`` file, the second is a src
    value passed to ``get_buggy_line_from_src``/``get_location_from_src``.
    """

    def __init__(self, source_map):
        self.source_map = source_map
        self.pairs = self._get_contracts_containing_selfdestruct_opcode()
        self.warnings = self._warnings()

    def is_vulnerable(self):
        """Return True if any callee contract contains a self-destruct opcode."""
        return bool(self.pairs)

    def _warnings(self):
        """Render one 'Parity Multisig Bug 2' warning per matching pair."""
        warnings = []
        for pair in self.pairs:
            src = pair[1]
            source_code = self.source_map.get_buggy_line_from_src(src)
            first_line, sep, _ = source_code.partition('\n')
            location = self.source_map.get_location_from_src(src)

            # root_path is a literal path prefix, not a regular expression.
            source = re.sub(re.escape(self.source_map.root_path), '',
                            self.source_map.get_filename())
            line = location['begin']['line'] + 1
            column = location['begin']['column'] + 1

            s = '%s:%s:%s: Warning: Parity Multisig Bug 2.\n' % (source, line, column)
            s += first_line
            if sep:
                # Leading spaces/tabs only (str.lstrip takes a character set).
                indent = first_line[:len(first_line) - len(first_line.lstrip(' \t'))]
                s += '\n' + indent + '^\n'
                s += 'Spanning multiple lines.'
            warnings.append(s)
        return warnings

    def _get_contracts_containing_selfdestruct_opcode(self):
        """Return the pairs whose ``.evm.disasm`` file mentions SELFDESTRUCT or SUICIDE."""
        selfdestruct = re.compile("SELFDESTRUCT|SUICIDE")
        ret = []
        for pair in self.source_map.callee_src_pairs:
            # Close the file deterministically instead of leaking the handle.
            with open(pair[0] + ".evm.disasm") as disasm_file:
                disasm_data = disasm_file.read()
            if selfdestruct.search(disasm_data):
                ret.append(pair)
        return ret

